"""
此模版专用于策略验证回测用
实现功能有：
1）多支股票同时回测；
2）可以指定具体股票代码
注：
1）回测数据，主要用处理后的通达信数据（1、通达信数据处理是指日线级别；
2、通达信分钟线数据提取，是指1分钟和5分钟的数据）
2）用想把通达信下载后数据，加入各种指标数据，可以借用dataprocessing.py
模块进行提前处理
3）可以借用策略筛选脚本，先进行股票筛选，然后再进行在线模拟回测
"""
import datetime  # 用于datetime对象操作
import os.path  # 用于管理路径
import backtrader as bt  # 引入backtrader框架
import pandas as pd

stk_num = 3        # 回测股票数目

""" ***********************************(策略部份)******************************** """


class MyStrategy(bt.Strategy):       # 创建策略
    # 可配置策略参数
    params = dict(
        pfast=5,             # 短期均线周期
        pslow=60,            # 长期均线周期
        poneplot=False,      # 是否打印到同一张图
        pstake=1000          # 单笔交易股票数目
    )

    def __init__(self):
        self.inds = dict()
        for i, d in enumerate(self.datas):                #  对于多支股票数据，需要使用enumerate拆包数据

            self.inds[d] = dict()
            self.inds[d]['sma1'] = bt.ind.SMA(d.close, period=self.p.pfast)             # 短期均线
            self.inds[d]['sma2'] = bt.ind.SMA(d.close, period=self.p.pslow)             # 长期均线
            self.inds[d]['cross'] = bt.ind.CrossOver(self.inds[d]['sma1'],
                                                     self.inds[d]['sma2'], plot=False)  # 交叉信号
            # 跳过第一只股票data，第一只股票data作为主图数据
            if i > 0:
                if self.p.poneplot:
                    d.plotinfo.plotmaster = self.datas[0]

    def next(self):
        for i, d in enumerate(self.datas):                   #  对于多支股票数据，需要使用enumerate拆包数据
            # dt, dn = self.datetime.date(), d._name         # 获取时间及股票代码
            pos = self.getposition(d).size
            if not pos:                                      # 不在场内，则可以买入
                if self.inds[d]['cross'] > 0:  # 如果金叉
                    self.buy(data=d, size=self.p.pstake)     # 买买买
            elif self.inds[d]['cross'] < 0:                  # 在场内，且死叉
                self.close(data=d)                           # 卖卖卖

    # ------------( 下面代码无需调整，仅做交易记录显示用 )-------------------
    def notify_order(self, order):
        if order.status in [order.Submitted, order.Accepted]:
            # 提交给代理或者由代理接收的买/卖订单 - 不做操作
            return
        # 检查订单是否执行完毕
        # 注意：如果没有足够资金，代理会拒绝订单
        if order.status in [order.Completed]:
            if order.isbuy():
                print("\033[35m日期：{}，股票代码：{}, 买入价格:{:.2f},Cost:{:.2f},股票数量:{:.2f},手续费:{:.2f}\033[0m"
                      .format(self.datas[0].datetime.date(0), order.data._name,
                              order.executed.price, order.executed.value, order.size, order.executed.comm))
                # 卖
            else:
                print("\033[32m日期：{}，股票代码：{},卖出价格:{:.2f},Cost:{:.2f},股票数量:{:.2f},手续费:{:.2f}\033[0m"
                      .format(self.datas[0].datetime.date(0), order.data._name,
                              order.executed.price, order.executed.price*order.size, order.size, order.executed.comm))

            # self.bar_executed = len(self)
            # 如果指令取消/交易失败, 报告结果

    # def stop(self):
    #     print("\033[32m Ending value:{:.2f}\033[0m" .format(self.broker.getvalue()))

    def notify_trade(self, trade):
        if not trade.isclosed:
            return
        # 显示交易的毛利率和净利润
        print("\033[31m 交易日期:{} 股票代码：{} 毛利润:{:.2f}, 净利润:{:.2f}\033[0m"
              .format(self.datetime.date(), trade.data._name, trade.pnl, trade.pnlcomm))
        # self.log('OPERATION PROFIT, GROSS %.2f, NET %.2f' %
        #          (trade.pnl, trade.pnlcomm))


""" ***********************************(数据导入)******************************** """
origin_data_path = '../data/day/'                     #  存放处理后通达信数据的文件夹 地址
start_datatime = datetime.datetime(2020, 8, 3)       #  导入的数据开始的有效时间也是回测起始时间
end_datatime = datetime.datetime(2021, 12, 3)         #  导入的数据结束的有效时间也是回测结束时间

# 创建cerebro
cerebro = bt.Cerebro()
# 读入股票代码
stk_code_num = []                                 #  用于提取股票代码的列表
stk_code_file = os.listdir(origin_data_path)    #  ../data/tdx/day/数据相对于脚本文件的地址，读取文件名
for stk_code in stk_code_file:
    index1 = stk_code.rfind(".")
    newname_code = stk_code[:index1]              # 去掉文件.csv尾缀 如剩下为”sh600000“
    # stk_path = '../data/tdx/day/'+stk_code
    # stk_data = pd.read_csv(stk_path, encoding='gbk')
    # stk_data['code']=newname_code
    stk_code_num.append(newname_code)             # 处理后的文件名，添加到列表中
stk_pools = pd.DataFrame(stk_code_num)            # 把列表转成DateFrame格式文件
if stk_num > stk_pools.shape[0]:                  # 判断自定义回测股票数 与有的股票数 对比
    print('股票数目不能大于%d' % stk_pools.shape[0])
    exit()
stk_pools =pd.read_csv('code.csv')       # **可以指定具体回测代码清单清单'code.csv'，在脚本文件目录下；如不需，则需要注释掉**#
stk_pools = pd.DataFrame(['sz000065', 'sz000058', 'sz000068'])     # 可以在脚本中，直接指定代码进行回测 不需要，要进行注释掉
for i in range(stk_num):
    list_star_code = 0                                # 可以设置股票选取数据段 如从数据表中100开始
    stk_code = stk_pools.iloc[i+list_star_code, 0]
    # 读入数据地址
    datapath = origin_data_path + stk_code + '.csv'
    # 创建价格数据
    data = bt.feeds.GenericCSVData(
        dataname=datapath,
        fromdate=start_datatime,
        todate=end_datatime,
        nullvalue=0.0,
        dtformat='%Y-%m-%d',
        datetime=0,
        open=1,
        high=2,
        low=3,
        close=4,
        volume=5,
        openinterest=-1
    )
    # 在Cerebro中添加股票数据
    cerebro.adddata(data, name=stk_code)

""" ***********************************(资金手续费设置等)******************************** """
# 设置启动资金
cerebro.broker.setcash(100000.0)
origin_cash = cerebro.broker.getvalue()
# 设置交易单位大小
# cerebro.addsizer(bt.sizers.FixedSize, stake=100)
# 设置佣金为千分之一
cerebro.broker.setcommission(commission=0.001)

cerebro.addstrategy(MyStrategy, poneplot=False)  # 添加策略
cerebro.run()  # 遍历所有数据
# 打印最后结果
print('初始资金：%.2f,最终资金: %.2f,总盈利：%.2f' % (origin_cash, cerebro.broker.getvalue(),
                                         cerebro.broker.getvalue()-origin_cash))

""" ***********************************(绘图)******************************** """
# 绘图时间段调整位置
start_plot = datetime.datetime(2020, 12, 1)
end_plot = datetime.datetime(2021, 12, 1)
# 绘图
cerebro.plot(style="candle", start=start_plot, end=end_plot,
             barup='red', bardown='green')

# cerebro.plot(style="candle", start=start_plot, end=end_plot,
#              barup = 'red', bardown = 'green',volume = False)  # 绘图
